
New pass Reduce variable liveness #3965


Open

mfrancepillois wants to merge 30 commits into main

Conversation

@mfrancepillois (Contributor) commented Apr 18, 2025:

Add a new pass to reduce variable liveness by prefetching data and then moving the load op closer to its use-op.

@mfrancepillois mfrancepillois requested review from whitneywhtsang, etiotto and a team April 18, 2025 10:52
@mfrancepillois

This comment was marked as outdated.

@mfrancepillois mfrancepillois changed the title Add pass: Reduce the register pressure New pass Reduce register pressure Apr 18, 2025
@mfrancepillois mfrancepillois linked an issue Apr 18, 2025 that may be closed by this pull request
@mfrancepillois mfrancepillois changed the title New pass Reduce register pressure [Draft] New pass Reduce register pressure Apr 18, 2025
@mfrancepillois mfrancepillois marked this pull request as draft April 18, 2025 11:40
@mfrancepillois mfrancepillois marked this pull request as ready for review April 18, 2025 16:31
@mfrancepillois mfrancepillois changed the title [Draft] New pass Reduce register pressure New pass Reduce register pressure Apr 18, 2025
@mfrancepillois mfrancepillois changed the title New pass Reduce register pressure New pass Reduce variable liveness Apr 24, 2025
@mfrancepillois mfrancepillois marked this pull request as draft April 30, 2025 16:53
Signed-off-by: Maxime France-Pillois <[email protected]>
@mfrancepillois mfrancepillois marked this pull request as ready for review April 30, 2025 17:46
@etiotto etiotto requested review from alexbaden and chengjunlu May 1, 2025 19:37
@whitneywhtsang (Contributor) left a comment:

There is a loop sink pass in IGC. Can you please create an issue for the IGC team to investigate why it doesn't catch the case of FA with the shape that gives the most gain?


/// Create a prefetch operation for the given load operation.
static void createPrefetchOp(tt::LoadOp loadOp) {
  Operation *op = loadOp.getPtr().getDefiningOp();
Contributor:

When did we check that loadOp.getPtr() is defined by an operation? Do we need to add that to isLoadCandidate?
Or should we add support for the case where the pointer is a region argument?

Contributor (Author):

Thanks for noticing. A check has been added to isLoadCandidate.
As the pass adds a prefetch right after the defining op, I'm concerned that adding this prefetch in another region (in the case where the load ptr has been defined in another region) could have side effects on the cache (as an early data fetch could mean evicting data that are still needed).

Contributor:

Do we care about the case where the pointer comes directly from a function argument?

@chengjunlu (Contributor) commented May 6, 2025:

It is good to have the reduce-variable-liveness pass as the starting point for liveness optimization in the Triton middle end.
This PR looks good to me as a beginning.

The optimization relies on the cache to hold the values that we may reuse in the loop, but the cache system is not fully controllable by the program. It would be better if we could enhance the pass to use shared local memory, making it somewhat like a RegisterToMem pass for the general case.

@etiotto (Contributor) commented May 6, 2025:

@mfrancepillois can you do a Triton Benchmark run with this PR to identify improvements (or degradations - hopefully none) in all the microbenchmarks we have?

@mfrancepillois mfrancepillois marked this pull request as draft May 12, 2025 13:11
Operation *forOp) {
// Only pointers to tensors are considered to be moved
if (!mlir::triton::isTensorPointerType(loadOp.getPtr().getType()))
if (!mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
Contributor:

[optional]

Suggested change
if (!mlir::triton::isTensorOrTensorPointerType(loadOp.getPtr().getType()))
if (!mlir::triton::isTensorPointerType(loadOp.getResult().getType()))

Contributor:

This would limit the optimization to block pointer loads. That is conservative, and I am OK with limiting the pass in this PR. Generally speaking, the pass should work for tensors of ptrs as well as block pointers.

Contributor (Author):

The current pass does handle block pointers AND tensors of pointers (with the condition that the load has an empty mask).

@mfrancepillois (Contributor, Author) commented:

> @mfrancepillois can you do a Triton Benchmark run with this PR to identify improvements (or degradations - hopefully none) in all the microbenchmarks we have?

After a few improvements to this pass (handling multiple users for the loadOp and improving the condition for a loadOp to be elected as movable), CIs have been run: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15215096477/job/42798609994

For flash-attention, we have the following performance:

[image: flash-attention performance results]

Other benchmarks do not seem to be significantly impacted by this pass.

@mfrancepillois mfrancepillois marked this pull request as ready for review May 27, 2025 10:52
// each "for loop" given that the liveness of variables may have changed
// as a result of the code, and specifically `LoadOps`, being modified
// by the pass.
Liveness livenessAnalysis(rootOperation);
Contributor:

To reduce compile time we should detect whether the pass made any changes to the code and only rerun the analysis if changes were made.

Contributor (Author):

The code has been modified to run the analysis only when needed.
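A minimal sketch of the pattern this thread converged on (names and types are invented stand-ins; the real pass walks MLIR scf.for ops and reruns the Liveness analysis): the expensive analysis is rebuilt only when a transformation reports that it actually changed the IR.

```cpp
#include <vector>

// Stand-in for an expensive analysis such as Liveness; we only count rebuilds.
struct AnalysisStub {
  int buildCount = 0;
  void rebuild() { ++buildCount; }
};

// Stand-in for the per-loop transformation; returns true when it changed
// something (e.g. moved a LoadOp closer to its use).
static bool transformLoop(int &loop) {
  if (loop % 2 == 0) {
    ++loop;
    return true;
  }
  return false;
}

// Processes all "loops", refreshing the analysis only after a change was
// actually made, as the review requested to keep compile time down.
int processLoops(std::vector<int> loops) {
  AnalysisStub liveness;
  liveness.rebuild(); // initial analysis
  for (int &l : loops) {
    if (transformLoop(l))
      liveness.rebuild(); // refresh only when the IR changed
  }
  return liveness.buildCount;
}
```

With this shape, a run that changes nothing pays for exactly one analysis build, regardless of how many loops are visited.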

}

Operation *rootOperation = getOperation();
rootOperation->walk([&](scf::ForOp forOp) {
@etiotto (Contributor), May 28, 2025:

OK, the pass for now only handles one kind of loop (scf.for). That is OK as a first cut; we might want/need to enhance it to also support while loops in the future.

Contributor (Author):

A comment has been added to keep track of this.

@etiotto (Contributor) left a comment:

Initial round of code review comments.

#define LARGE_TENSOR_SIZE_THRESHOLD_IN_BYTES \
  LARGE_TENSOR_MAJOR_SHAPE_THRESHOLD * LARGE_TENSOR_MINOR_SHAPE_THRESHOLD * 2

static unsigned getSizeInBytes(RankedTensorType &tensorType) {
Contributor:

Add documentation for this function and the next pls.

Contributor:

[nit] static is unnecessary because these utilities are in an anonymous namespace.
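For reference, a hedged, MLIR-free sketch of what a documented getSizeInBytes helper might compute (the real function takes an MLIR RankedTensorType; the signature here is invented): the tensor's footprint is the element count times the element width in bytes.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

namespace {
/// Returns the size in bytes of a tensor with the given shape and element
/// bit-width. Illustrative stand-in for the pass's helper, which takes a
/// RankedTensorType instead of a raw shape; `static` is omitted since the
/// anonymous namespace already gives internal linkage, per the nit above.
unsigned getSizeInBytes(const std::vector<int64_t> &shape,
                        unsigned elemBitWidth) {
  int64_t numElems = std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                     std::multiplies<int64_t>());
  return static_cast<unsigned>(numElems * elemBitWidth / 8);
}
} // namespace
```

For example, a 64x256 f16 tensor occupies 64 * 256 * 2 = 32768 bytes, which matches the 32768-byte TOTAL_BLOCK_SIZE_THRESHOLD_IN_BYTES used by the pass.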

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
Contributor:

[nit] move in the section where other llvm include headers are "included".

#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/Debug.h"

#include "intel/include/Analysis/Liveness.h"
Contributor:

[nit] let's try to keep include headers in their sections (all Intel headers together, all Triton upstream headers together, etc...)


namespace {

#define TOTAL_BLOCK_SIZE_THRESHOLD_IN_BYTES 32768
Contributor:

Suggest to use C++ static constexpr instead of #defines.

Contributor (Author):

The code has been updated this way.
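A sketch of the constexpr style the reviewer suggested. The two shape-threshold values below are placeholders invented for illustration (the PR excerpt does not show them); only the 32768 total-block threshold appears in the source shown here.

```cpp
namespace {
// constexpr constants replace the #define thresholds; a `static` qualifier
// would be redundant inside an anonymous namespace but harmless.
// NOTE: the two shape thresholds below are illustrative placeholders, not
// the pass's actual values.
constexpr unsigned LARGE_TENSOR_MAJOR_SHAPE_THRESHOLD = 128;
constexpr unsigned LARGE_TENSOR_MINOR_SHAPE_THRESHOLD = 128;
constexpr unsigned LARGE_TENSOR_SIZE_THRESHOLD_IN_BYTES =
    LARGE_TENSOR_MAJOR_SHAPE_THRESHOLD * LARGE_TENSOR_MINOR_SHAPE_THRESHOLD * 2;
// This value does appear in the PR's source excerpt.
constexpr unsigned TOTAL_BLOCK_SIZE_THRESHOLD_IN_BYTES = 32768;
} // namespace
```

Unlike the macro version, the derived threshold is evaluated with normal operator precedence and is visible to the debugger and to static_assert checks.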

// The variable is considered as a long life span elected for being moved if:
// The live-in variables of the forOp consist of a large amount of bytes and
// The variable defined by `v` is a large tensor (with a large amount of elements
// in the minor dimension) and The variable liveness of `v` extends before
Contributor:

The -> the

return false;

for (triton::DotOp dot : dotsInFor) {
auto aVals = getLoad(dot.getA());
Contributor:

Use static types on LHS pls.

#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
tt.func public @matmul_kernel_small_tensor(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
// CHECK-LABEL: tt.func public @matmul_kernel_small_tensor
Contributor:

remove tt.func public here

ttig.prefetch %1 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>} : !tt.ptr<tensor<64x256xf16, #dot1>>
%4:2 = scf.for %arg2 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg3 = %cst, %arg4 = %1) -> (tensor<16x256xf32, #dpas>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
// CHECK: scf.for
// CHECK-NOT: tt.load {{.*}} : !tt.ptr<tensor<16x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
Contributor:

OK, so this test checks that the load for operand A (opIdx==0) is not sunk into the loop. It would be helpful to add a COM to all the tests to briefly explain what each test is designed to cover.

Contributor (Author):

Comments have been added to describe the goal of each test.

Successfully merging this pull request may close these issues.

[FA performance] Improve the Q matrix load strategy
4 participants